Fix #20322 — Weight save/load fails when model uses a custom subclass of a built-in layer (e.g. LSTM) #22360
Conversation
…state serialization

The `_save_container_state` and `_load_container_state` functions used class names to construct HDF5 paths for container layers. This meant saving weights from a custom LSTM subclass (e.g. `MyCustomLSTM`) and loading into a vanilla `LSTM` model would fail even when layer names matched, because the HDF5 paths differed (`my_custom_lstm` vs `lstm`). Use `saveable.name` (the layer name) instead of the class name for topology-based path matching. Add a backward-compatible fallback to class-name paths so files saved with the old format still load.

Fixes keras-team#20322
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request addresses a critical issue where Keras model weight saving and loading would fail or silently skip weights when models incorporated custom subclasses of built-in layers, such as `LSTM`.

Highlights
Activity
Code Review
This pull request introduces a valuable fix for model weight saving and loading when using custom subclasses of built-in layers, improving compatibility by prioritizing the layer's name over its class for key generation and including a fallback for backward compatibility. However, it introduces a critical path traversal vulnerability by using unsanitized layer names to construct file paths during model saving and loading, which could allow a malicious model to read or write arbitrary files on the host system. Sanitizing layer names before using them as path components is strongly recommended. Additionally, there's a potential bug in the new _store_has_path helper function that could affect sharded weights, and the new test case could be more comprehensive by covering the reverse scenario.
```python
if hasattr(saveable, "name") and isinstance(saveable.name, str):
    return saveable.name
return naming.to_snake_case(saveable.__class__.__name__)
```
The `_get_container_item_name` function returns `saveable.name` without any sanitization. This name is then used to construct file paths for saving and loading model assets via `DiskIOStore`. Since `saveable.name` can be controlled by a malicious model configuration file (`config.json`), an attacker can use path traversal sequences (e.g. `../`) or absolute paths to read or write files outside the intended temporary directory. For example, a layer named `../../../../etc/passwd` would cause `DiskIOStore.get` to return `/etc/passwd`, which is then passed to the layer's `load_assets` method, potentially leading to an arbitrary file read.
```diff
 if hasattr(saveable, "name") and isinstance(saveable.name, str):
-    return saveable.name
+    return naming.to_snake_case(saveable.name)
 return naming.to_snake_case(saveable.__class__.__name__)
```
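One possible shape of the mitigation the reviewer asks for is to restrict each name to a single safe path component before it ever reaches the filesystem. The sketch below is hypothetical (`sanitize_component` is not part of this PR or of Keras); it only illustrates the idea of neutralizing separators and dot sequences:

```python
import re


def sanitize_component(name):
    # Hypothetical sanitizer (not part of this PR): keep only characters
    # that are safe inside a single path component, so "../" sequences and
    # absolute paths cannot escape the archive directory.
    cleaned = re.sub(r"[^0-9A-Za-z._-]", "_", name)
    # Refuse names that reduce to nothing but dots/underscores.
    if cleaned.strip("._") == "":
        raise ValueError(f"unsafe name: {name!r}")
    return cleaned


print(sanitize_component("my_custom_lstm"))         # unchanged
print(sanitize_component("../../../../etc/passwd")) # slashes replaced
```

Because every `/` is replaced, the attacker-controlled name from the comment above can no longer traverse out of the store's directory.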
```python
def _store_has_path(weights_store, path):
    """Check if a path exists in the weights store."""
    if weights_store is None or not path:
        return False
    if isinstance(weights_store, H5IOStore):
        return path in weights_store.h5_file
    if isinstance(weights_store, NpzIOStore):
        return path in weights_store.contents
    return True
```
The implementation of _store_has_path has a couple of potential issues that could lead to incorrect behavior, especially with sharded weights.
1. **Incorrect type-checking order:** `ShardedH5IOStore` is a subclass of `H5IOStore`. The current check `isinstance(weights_store, H5IOStore)` will evaluate to `True` for a `ShardedH5IOStore` instance, causing it to execute logic that is incorrect for sharded stores, as it only checks the current shard. The check for `ShardedH5IOStore` should be performed before the check for `H5IOStore`.
2. **Unsafe fallback:** The function returns `True` as a fallback for any store type that is not `H5IOStore` or `NpzIOStore`. This is unsafe because it assumes the path exists in any unknown or unhandled store type. It would be safer to return `False`.
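The ordering pitfall comes from how `isinstance` treats subclasses; a minimal illustration with toy stand-ins for the real store classes (the actual classes live in `keras.src.saving.saving_lib`):

```python
# Toy stand-ins, only to show the isinstance ordering pitfall.
class H5IOStore:
    pass


class ShardedH5IOStore(H5IOStore):
    pass


store = ShardedH5IOStore()
# A subclass instance also satisfies the base-class check, so a chain of
# isinstance tests must test the subclass *before* the base class.
print(isinstance(store, H5IOStore))  # True, even though store is sharded
```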
Here is a suggested implementation that addresses these points:
```diff
 def _store_has_path(weights_store, path):
     """Check if a path exists in the weights store."""
     if weights_store is None or not path:
         return False
+    # `ShardedH5IOStore` must be checked before `H5IOStore` due to inheritance.
+    if isinstance(weights_store, ShardedH5IOStore):
+        weight_map = weights_store.sharding_config["weight_map"]
+        # The path in the weight map is typically `/{path}/vars`.
+        return path in weight_map or f"/{path}/vars" in weight_map
     if isinstance(weights_store, H5IOStore):
         return path in weights_store.h5_file
     if isinstance(weights_store, NpzIOStore):
         return path in weights_store.contents
-    return True
+    return False
```
```python
# Verify predictions match
x = np.random.random((1, 10, 1)).astype("float32")
self.assertAllClose(model_a(x), model_b(x))
```
The test docstring mentions testing loading 'and vice versa', but the test currently only covers saving from a custom subclass and loading into the base class. To make the test more comprehensive and match its description, consider adding a test for the other direction: saving from the base LSTM and loading into the CustomLSTM subclass.
```python
# Verify predictions match
x = np.random.random((1, 10, 1)).astype("float32")
self.assertAllClose(model_a(x), model_b(x))

# Test the other direction: save from vanilla, load into subclass
temp_filepath_2 = os.path.join(
    self.get_temp_dir(), "vanilla_lstm.weights.h5"
)
# Re-build model_a to reset its weights before loading
inputs_a_2 = keras.Input(shape=(10, 1))
lstm_a_2 = CustomLSTM(32, name="my_lstm")
dense_a_2 = keras.layers.Dense(1, name="output")
model_a_2 = keras.Model(
    inputs_a_2, dense_a_2(lstm_a_2(inputs_a_2)), name="model_a"
)
model_b.save_weights(temp_filepath_2)
model_a_2.load_weights(temp_filepath_2)
# Verify predictions match
self.assertAllClose(model_a_2(x), model_b(x))
```
Fixes: #20322
This pull request improves the robustness and compatibility of Keras model weight saving and loading, especially when dealing with custom layer subclasses that share names with built-in layers. The changes ensure that weights can be reliably transferred between models with matching layer names, even if the layer classes differ, and add comprehensive tests to cover this scenario.
Core logic improvements:

- Added a `_get_container_item_name` helper to consistently use the layer's `name` for topology-based matching when saving/loading weights, with a fallback to the class name for unnamed saveables. This ensures compatibility between custom and base-class layers with the same name.
- Updated `_save_container_state` and `_load_container_state` to use the new naming logic, and in `_load_container_state`, added a fallback mechanism: if the new name-based path isn't found in the weights store, it tries the legacy class-name-based path for backward compatibility. [1] [2]

Testing improvements:

- Added a regression test (`test_custom_subclass_weight_loading`) to verify that weights saved from a custom subclassed layer can be loaded into a model with the base class (and vice versa) when the layer names match, addressing issue #20322 ("Loading weights into custom LSTM layer fails: Layer 'lstm_cell' expected 3 variables, but received 0 variables during loading. Expected: ['kernel', 'recurrent_kernel', 'bias']").

Problem
When saving weights, Keras walks the layer tree and stores each sublayer keyed by `to_snake_case(cls.__name__)`. If the saved model uses `class CustomLSTM(LSTM)`, the key written is `custom_lstm`. When loading into a model that uses the base `LSTM` class (same layer name, same weights), the key expected is `lstm`: a mismatch, and the weights are silently skipped.

The same problem occurs in reverse: saving with the base class and loading into a subclass.
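The key mismatch is easy to see with a snake-case conversion. The helper below is a simplified stand-in modeled on `keras.src.utils.naming.to_snake_case`, not the real Keras function, but it behaves the same on these class names:

```python
import re


def to_snake_case(name):
    # Simplified stand-in for keras.src.utils.naming.to_snake_case:
    # insert underscores at case boundaries, then lowercase.
    name = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name)
    name = re.sub(r"([a-z])([A-Z])", r"\1_\2", name)
    return name.lower()


print(to_snake_case("LSTM"))        # lstm
print(to_snake_case("CustomLSTM"))  # custom_lstm  <- different key, same layer name
```

Two models whose layers are named identically still produce different HDF5 keys when one uses a subclass, which is exactly the mismatch described above.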
Root Cause
The original code comment acknowledged this explicitly but chose class-name-based keys to avoid autogenerated-name drift between instances. However, this choice breaks cross-class weight transfer even when the layer names are identical and the weight shapes match exactly.

Fix
Introduce `_get_container_item_name(saveable)`: prefer `saveable.name` (the user-assigned or Keras-deduped string name) as the key, falling back to the class-name approach only when no string name is available. On load, add a backward-compatible fallback: if the name-based path doesn't exist in the store, retry with the old class-name-based path. This preserves compatibility with weights files saved before this fix.

Files Changed
- `keras/src/saving/saving_lib.py` — `_get_container_item_name()`, `_store_has_path()`, updated `_save_container_state` and `_load_container_state`
- `keras/src/saving/saving_lib_test.py` — regression test: save from `CustomLSTM`, load into vanilla `LSTM`, assert predictions match
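The name-first lookup with the legacy class-name fallback can be sketched with a plain dict standing in for the weights store. All names here are illustrative, not the actual `saving_lib` internals:

```python
import re


class MyCustomLSTM:
    """Toy saveable; only the `name` attribute matters for this sketch."""
    name = "lstm_1"


def legacy_key(obj):
    # Stand-in for naming.to_snake_case(type(obj).__name__).
    s = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", type(obj).__name__)
    return re.sub(r"([a-z])([A-Z])", r"\1_\2", s).lower()


def resolve_path(store, obj, prefix="layers"):
    # Prefer the new name-based path; fall back to the legacy
    # class-name-based path for files saved before the fix.
    name_path = f"{prefix}/{obj.name}"
    if name_path in store:
        return name_path
    return f"{prefix}/{legacy_key(obj)}"


old_file = {"layers/my_custom_lstm": "weights"}  # legacy, class-name format
new_file = {"layers/lstm_1": "weights"}          # new, name-based format
print(resolve_path(old_file, MyCustomLSTM()))    # layers/my_custom_lstm
print(resolve_path(new_file, MyCustomLSTM()))    # layers/lstm_1
```

Old files resolve through the legacy key, new files through the layer name, which is the backward-compatibility guarantee the Fix section describes.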